/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.sdk.transforms.join;



import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.coders.CoderException;

import com.bff.gaia.unified.sdk.coders.StructuredCoder;

import com.bff.gaia.unified.sdk.util.VarInt;

import com.bff.gaia.unified.sdk.util.common.ElementByteSizeObserver;



import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;



/**
 * A {@link Coder} for {@link RawUnionValue}s.
 *
 * <p>Each value is encoded as its union tag (written with {@link VarInt}) followed by the value
 * itself, encoded with the element coder registered for that tag. Union tags start at 0 and index
 * into the list of element coders supplied to {@link #of(List)}.
 */
public class UnionCoder extends StructuredCoder<RawUnionValue> {
  // TODO: Think about how to integrate this with a schema object (i.e.
  // a tuple of tuple tags).

  /** Element coders indexed by union tag; an unmodifiable snapshot taken at construction. */
  private final List<Coder<?>> elementCoders;

  /**
   * Builds a union coder with the given list of element coders. This list corresponds to a mapping
   * of union tag to Coder. Union tags start at 0.
   *
   * <p>The list is copied, so later modifications to {@code elementCoders} do not affect the
   * returned coder.
   *
   * @param elementCoders the coder for each union tag, in tag order
   * @return a new {@code UnionCoder}
   * @throws NullPointerException if {@code elementCoders} is null
   */
  public static UnionCoder of(List<Coder<?>> elementCoders) {
    return new UnionCoder(elementCoders);
  }

  private UnionCoder(List<Coder<?>> elementCoders) {
    Objects.requireNonNull(elementCoders, "elementCoders");
    // Defensive copy: the coder must not change if the caller later mutates its list, and the
    // accessors below must not hand out a mutable view of internal state.
    this.elementCoders = Collections.unmodifiableList(new ArrayList<>(elementCoders));
  }

  /**
   * Returns the union tag of {@code union}, verifying that it selects one of the element coders.
   *
   * @throws IllegalArgumentException if {@code union} is null or its tag is out of range
   */
  private int getIndexForEncoding(RawUnionValue union) {
    if (union == null) {
      throw new IllegalArgumentException("cannot encode a null tagged union");
    }
    int index = union.getUnionTag();
    if (index < 0 || index >= elementCoders.size()) {
      throw new IllegalArgumentException(
          "union value index " + index + " not in range [0.." + (elementCoders.size() - 1) + "]");
    }
    return index;
  }

  /** Returns the element coder for a tag that the caller has already validated as in range. */
  @SuppressWarnings("unchecked")
  private Coder<Object> coderFor(int index) {
    // Unchecked: a value carrying this tag is assumed to be of the type this coder handles.
    return (Coder<Object>) elementCoders.get(index);
  }

  @Override
  public void encode(RawUnionValue union, OutputStream outStream)
      throws IOException, CoderException {
    encode(union, outStream, Context.NESTED);
  }

  /**
   * Writes the union tag with {@link VarInt}, then writes the value using the element coder for
   * that tag. Because the value is the final thing written, {@code context} is passed through to
   * the element coder.
   *
   * @throws IllegalArgumentException if {@code union} is null or its tag has no element coder
   */
  @Override
  public void encode(RawUnionValue union, OutputStream outStream, Context context)
      throws IOException, CoderException {
    int index = getIndexForEncoding(union);
    // Write out the union tag.
    VarInt.encode(index, outStream);
    // Write out the actual value.
    coderFor(index).encode(union.getValue(), outStream, context);
  }

  @Override
  public RawUnionValue decode(InputStream inStream) throws IOException, CoderException {
    return decode(inStream, Context.NESTED);
  }

  /**
   * Reads a union tag, then reads the value using the element coder for that tag.
   *
   * @throws CoderException if the decoded tag has no corresponding element coder (typically
   *     corrupt input, or data written by a coder with a different set of element coders)
   */
  @Override
  public RawUnionValue decode(InputStream inStream, Context context)
      throws IOException, CoderException {
    int index = VarInt.decodeInt(inStream);
    if (index < 0 || index >= elementCoders.size()) {
      // Report a coder error instead of letting List.get throw IndexOutOfBoundsException.
      throw new CoderException(
          "decoded union tag " + index + " not in range [0.." + (elementCoders.size() - 1) + "]");
    }
    Object value = elementCoders.get(index).decode(inStream, context);
    return new RawUnionValue(index, value);
  }

  /** Returns an empty list; the element coders are reported through {@link #getComponents()}. */
  @Override
  public List<? extends Coder<?>> getCoderArguments() {
    return Collections.emptyList();
  }

  /** Returns the element coders, in union-tag order, as an unmodifiable list. */
  @Override
  public List<? extends Coder<?>> getComponents() {
    return elementCoders;
  }

  /** Returns the element coders, in union-tag order, as an unmodifiable list. */
  public List<? extends Coder<?>> getElementCoders() {
    return elementCoders;
  }

  /**
   * Delegates to the element coder selected by the value's union tag. Encoding the tag itself is
   * cheap, so the overall cost is determined by that element coder.
   *
   * @throws IllegalArgumentException if {@code union} is null or its tag has no element coder
   */
  @Override
  public boolean isRegisterByteSizeObserverCheap(RawUnionValue union) {
    int index = getIndexForEncoding(union);
    return coderFor(index).isRegisterByteSizeObserverCheap(union.getValue());
  }

  /**
   * Notifies ElementByteSizeObserver about the byte size of the encoded value using this coder:
   * the length of the VarInt-encoded union tag, plus whatever the element coder reports.
   */
  @Override
  public void registerByteSizeObserver(RawUnionValue union, ElementByteSizeObserver observer)
      throws Exception {
    int index = getIndexForEncoding(union);
    // Account for the union tag.
    observer.update(VarInt.getLength(index));
    // Account for the actual value.
    coderFor(index).registerByteSizeObserver(union.getValue(), observer);
  }

  /** This coder is deterministic exactly when every element coder is deterministic. */
  @Override
  public void verifyDeterministic() throws NonDeterministicException {
    verifyDeterministic(
        this, "UnionCoder is only deterministic if all element coders are", elementCoders);
  }
}